{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/05-trainer-flags-overview.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "goRmGIRI5cfC"
   },
   "source": [
    "# Introduction to Lightning Flags ⚡🚩\n",
    "\n",
    "In this notebook, we'll go over the flags available in the `Trainer` object. Note that not everything will work in the Colab environment (multi-GPU, etc.). This notebook accompanies the Trainer videos we'll be putting out.\n",
    "\n",
    "---\n",
    "  - Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
    "  - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
    "  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jKj5lgdr5j48"
   },
   "source": [
    "--- \n",
    "### Setup  \n",
    "First things first: install Lightning with `pip install pytorch-lightning`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UGjilEHk4vb7"
   },
   "outputs": [],
   "source": [
    "! pip install pytorch-lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zaVUShmQ5n8Y"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from argparse import ArgumentParser\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import random_split\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision import transforms\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.metrics.functional import accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6tgkS8IYZwY_"
   },
   "outputs": [],
   "source": [
    "# ------------\n",
    "# data\n",
    "# ------------\n",
    "pl.seed_everything(1234)\n",
    "batch_size = 32\n",
    "\n",
    "# Init DataLoader from MNIST Dataset\n",
    "\n",
    "dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n",
    "mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())\n",
    "mnist_train, mnist_val = random_split(dataset, [55000, 5000])\n",
    "\n",
    "train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(mnist_val, batch_size=batch_size)\n",
    "test_loader = DataLoader(mnist_test, batch_size=batch_size)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gEulmrbxwaYL"
   },
   "source": [
    "### Simple AutoEncoder Model\n",
    "\n",
    "We're going to define a simple Lightning model so we can play with all the settings of the Lightning Trainer.\n",
    "\n",
    "A LightningModule is plain PyTorch reorganized into hooks that represent the steps of the training process.\n",
    "\n",
    "You can use LightningModule hooks to control every part of your model, but for the purpose of this video we will use a very simple model trained on MNIST, a dataset of 28x28 grayscale images of handwritten digits (0-9).\n",
    "\n",
    "A LightningModule can encompass a single model, like an image classifier, or a deep learning system composed of multiple models, like this autoencoder, which contains an encoder and a decoder.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "x-34xKCI40yW"
   },
   "outputs": [],
   "source": [
    "class LitAutoEncoder(pl.LightningModule):\n",
    "\n",
    "    def __init__(self, batch_size=32, lr=1e-3):\n",
    "        super().__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(28 * 28, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 3)\n",
    "        )\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(3, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 28 * 28)\n",
    "        )\n",
    "        self.batch_size = batch_size\n",
    "        self.learning_rate = lr\n",
    "\n",
    "    def forward(self, x):\n",
    "        # in lightning, forward defines the prediction/inference actions\n",
    "        embedding = self.encoder(x)\n",
    "        return embedding\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        x = x.view(x.size(0), -1)\n",
    "        z = self.encoder(x)\n",
    "        x_hat = self.decoder(z)\n",
    "        loss = F.mse_loss(x_hat, x)\n",
    "        self.log('train_loss', loss)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        x = x.view(x.size(0), -1)\n",
    "        z = self.encoder(x)\n",
    "        x_hat = self.decoder(z)\n",
    "        loss = F.mse_loss(x_hat, x)\n",
    "        self.log('val_loss', loss)\n",
    "        \n",
    "    def test_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        x = x.view(x.size(0), -1)\n",
    "        z = self.encoder(x)\n",
    "        x_hat = self.decoder(z)\n",
    "        loss = F.mse_loss(x_hat, x)\n",
    "        self.log('test_loss', loss)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
    "        return optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VbxcRCrxiYly"
   },
   "source": [
    "You'll notice the LightningModule has no epoch or batch loops, never calls model.train() or model.eval(), and makes no mention of CUDA or hardware. That's because the Lightning Trainer automates all of this engineering boilerplate:\n",
    "\n",
    "*  Training loops\n",
    "*  Evaluation and test loops\n",
    "*  Calling model.train(), model.eval() and torch.no_grad() at the right time\n",
    "*  CUDA or .to(device) calls\n",
    "\n",
    "It also allows you to train your models on different hardware like GPUs and TPUs without changing your code!\n",
    "\n",
    "\n",
    "### To use the Lightning Trainer:\n",
    "\n",
    "1. Initialize your LightningModule and datasets\n",
    "\n",
    "2. Initialize the Lightning Trainer\n",
    "\n",
    "3. Call trainer.fit\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HOk9c4_35FKg"
   },
   "outputs": [],
   "source": [
    "#####################\n",
    "# 1. Init Model\n",
    "#####################\n",
    "\n",
    "model = LitAutoEncoder()\n",
    "\n",
    "#####################\n",
    "# 2. Init Trainer\n",
    "#####################\n",
    "\n",
    "# these 2 flags are explained in later sections, but in short:\n",
    "# - progress_bar_refresh_rate: limits refresh rate of tqdm progress bar so Colab doesn't freak out\n",
    "# - max_epochs: only run 2 epochs instead of default of 1000\n",
    "trainer = pl.Trainer(progress_bar_refresh_rate=20, max_epochs=2)\n",
    "\n",
    "#####################\n",
    "# 3. Train\n",
    "#####################\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3meDako-Qa_6"
   },
   "source": [
    "Our model is training just like that, using the Lightning defaults. The beauty of Lightning is that everything is easily configurable.\n",
    "In our next videos, we're going to show you all the ways you can control your Trainer: controlling your training, validation and test loops, running on GPUs and TPUs, checkpointing, early stopping, and a lot more.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z_Wry2MckQkI"
   },
   "source": [
    "# Training loop and eval loop Flags"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0MkI1xB2vsLj"
   },
   "source": [
    "\n",
    "To really scale up your networks, you can use accelerators like GPUs. GPUs, or Graphics Processing Units, parallelize matrix multiplications, which can speed up training dramatically compared with CPUs.\n",
    "\n",
    "Let's say you have a machine with 8 GPUs. You can set the `gpus` flag to 1, 4, or 8, and Lightning will automatically distribute your training for you.\n",
    "\n",
    "```\n",
    "trainer = pl.Trainer(gpus=1)\n",
    "```\n",
    "\n",
    "---------\n",
    "\n",
    "Lightning makes your code hardware-agnostic: you can switch between CPUs and GPUs without code changes.\n",
    "\n",
    "However, it requires forming good PyTorch habits:\n",
    "\n",
    "1. First, remove the .cuda() or .to() calls in your code.\n",
    "2. Second, when you initialize a new tensor, pass `device=self.device`, since every LightningModule knows which GPU index or TPU core it is on.\n",
    "\n",
    "Alternatively, use `type_as`, or register the tensor as a buffer in your module's `__init__` method with `register_buffer()`.\n",
    "\n",
    "```\n",
    "# before lightning\n",
    "def forward(self, x):\n",
    "    z = torch.Tensor(2, 3)\n",
    "    z = z.cuda(0)\n",
    "\n",
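    "# with lightning: create the tensor on the module's device\n",
    "def forward(self, x):\n",
    "    z = torch.empty(2, 3, device=self.device)\n",
    "\n",
    "# or match the dtype and device of an existing tensor\n",
    "def forward(self, x):\n",
    "    z = torch.Tensor(2, 3).type_as(x)\n",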
    "```\n",
    "\n",
    "\n",
    "```\n",
    "class LitModel(pl.LightningModule):\n",
    "\n",
    "    def __init__(self):\n",
    "        ...\n",
    "        self.register_buffer(\"sigma\", torch.eye(3))\n",
    "        # you can now access self.sigma anywhere in your module\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hw6jJhhjvlSL"
   },
   "source": [
    "The Lightning Trainer automates all the engineering boilerplate: iterating over epochs and batches, the training, evaluation and test loops, CUDA and .to(device) calls, and calling model.train() and model.eval().\n",
    "\n",
    "You still have full control over the loops, by using the following trainer flags:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pT5-ETH9eUg6"
   },
   "source": [
    "## Calling validation steps\n",
    "Sometimes an epoch trains quickly, in a matter of minutes. In this case, you might not need to validate after every epoch; validating every few epochs can be enough.\n",
    "\n",
    "Use the `check_val_every_n_epoch` flag to control how often the validation loop runs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z-EMVvKheu3D"
   },
   "outputs": [],
   "source": [
    "# run val loop every 10 training epochs\n",
    "trainer = pl.Trainer(check_val_every_n_epoch=10)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UOzZr9S2UcSO"
   },
   "source": [
    "## val_check_interval\n",
    "\n",
    "When your epochs are very long, you might want to check validation within an epoch.\n",
    "\n",
    "You can run the validation loop within your training epochs by setting the `val_check_interval` flag.\n",
    "\n",
    "Set `val_check_interval` to a float in the range [0.0, 1.0] to check your validation set within a training epoch. For example, setting it to 0.25 will check your validation set 4 times during a training epoch.\n",
    "\n",
    "The default is 1.0, which checks the validation set once per training epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9kbUbvrUVLrT"
   },
   "outputs": [],
   "source": [
    "# check validation set 4 times during a training epoch\n",
    "trainer = pl.Trainer(val_check_interval=0.25)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
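  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic behind `val_check_interval` can be sketched in plain Python. This is a simplified model, not Lightning's internal scheduler: it assumes that a float triggers validation every `int(num_batches * interval)` training batches and an int every `interval` batches. `validation_batch_indices` is a hypothetical helper introduced only for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def validation_batch_indices(num_batches, interval):\n",
    "    # Hypothetical helper: the 1-based batch counts after which\n",
    "    # validation would run during one training epoch.\n",
    "    if isinstance(interval, float):\n",
    "        # a float is a fraction of the epoch, e.g. 0.25 -> every quarter epoch\n",
    "        step = max(1, int(num_batches * interval))\n",
    "    else:\n",
    "        # an int is a fixed number of training batches\n",
    "        step = interval\n",
    "    return [i for i in range(1, num_batches + 1) if i % step == 0]\n",
    "\n",
    "\n",
    "# 55,000 training images with batch_size=32 -> 1719 batches per epoch\n",
    "num_batches = math.ceil(55000 / 32)\n",
    "print(validation_batch_indices(num_batches, 0.25))  # [429, 858, 1287, 1716]\n",
    "```\n",
    "\n",
    "With the same 1719 batches, the default of 1.0 yields a single check after the last batch (`[1719]`), and an int such as 1000 yields `[1000]`."
   ]
  },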
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Onm1gBsKVaw4"
   },
   "source": [
    "When you use iterable datasets, or stream data for production use cases, it is useful to check the validation set every fixed number of training batches.\n",
    "Set `val_check_interval` to an int:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "psn6DVb5Vi85"
   },
   "outputs": [],
   "source": [
    "# check validation set every 1000 training batches\n",
    "# use this when using an IterableDataset and your dataset has no length\n",
    "# (i.e., production cases with streaming data)\n",
    "trainer = pl.Trainer(val_check_interval=1000)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QkoYonrWkb7-"
   },
   "source": [
    "## num_sanity_val_steps \n",
    "\n",
    "You may have run into an issue where your validation loop has a bug that you won't catch until your first training epoch ends.\n",
    "\n",
    "If that epoch takes hours or days, you waste valuable compute.\n",
    "\n",
    "Instead, Lightning automatically runs 2 validation batches at the beginning to catch these kinds of bugs up front.\n",
    "\n",
    "The `num_sanity_val_steps` flag controls how many validation batches run before the training routine starts.\n",
    "\n",
    "You can set it to 0 to turn the check off."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zOcT-ugSkiKW"
   },
   "outputs": [],
   "source": [
    "# turn it off\n",
    "trainer = pl.Trainer(num_sanity_val_steps=0)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zS0ob1ZmTw56"
   },
   "source": [
    "Set it to -1 to check all validation data before training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rzqvjA4UT263"
   },
   "outputs": [],
   "source": [
    "# check all validation data\n",
    "trainer = pl.Trainer(num_sanity_val_steps=-1)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uMB41wq4T3Z2"
   },
   "source": [
    "Or use any number of validation batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lGP78aQzT7VS"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(num_sanity_val_steps=10)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
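  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The settings above can be summarized with a small, hypothetical helper. It only models the semantics described in this section (0 disables the check, -1 runs every validation batch, any other value runs at most that many batches, with 2 as the default) and is not Lightning's implementation.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sanity_val_batches(num_sanity_val_steps, num_val_batches):\n",
    "    # Hypothetical helper modelling the flag's documented semantics.\n",
    "    if num_sanity_val_steps == -1:\n",
    "        # -1: run the whole validation set\n",
    "        return num_val_batches\n",
    "    # 0 disables the check; any other n runs at most n batches\n",
    "    return min(num_sanity_val_steps, num_val_batches)\n",
    "\n",
    "\n",
    "# 5,000 validation images with batch_size=32 -> 157 batches\n",
    "num_val_batches = math.ceil(5000 / 32)\n",
    "for steps in (2, 0, -1, 10):\n",
    "    print(steps, '->', sanity_val_batches(steps, num_val_batches))\n",
    "```\n",
    "\n",
    "For the 157 validation batches, the loop reports 2, 0, 157 and 10 sanity batches respectively."
   ]
  },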
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "H-xaYRtd1rb-"
   },
   "source": [
    "## Limit train, validation, and test batches\n",
    "\n",
    "You can limit how much of the training, validation and test datasets your model runs through. This is useful if you have very large validation or test sets, or when debugging or testing something that happens at the end of an epoch.\n",
    "\n",
    "Set the flag to an int to specify the number of batches to run:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XiK5cFKL1rcA"
   },
   "outputs": [],
   "source": [
    "# run only 10 batches of the test set\n",
    "trainer = pl.Trainer(limit_test_batches=10)\n",
    "\n",
    "trainer.test(model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y4LK0g65RrBm"
   },
   "source": [
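    "For example, some metrics, such as AUC-ROC, need to be computed over the entire validation results at the end of the epoch. Limiting the validation batches lets you test that end-of-epoch code quickly:"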
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8MmeRs2DR3dD"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(limit_val_batches=10)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xmigcNa1A2Vy"
   },
   "source": [
    "You can also use a float to limit each loader to a fraction of its batches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "W7uGJt8nA4tv"
   },
   "outputs": [],
   "source": [
    "# run through only 25% of the test set\n",
    "trainer = pl.Trainer(limit_test_batches=0.25)\n",
    "\n",
    "trainer.test(model, test_loader)"
   ]
  },
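  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of how an int or a float limit could map to a number of batches. It assumes a float is multiplied by the loader length and rounded down, and an int is capped at the loader length; `limited_batches` is a hypothetical helper for illustration only.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def limited_batches(limit, num_batches):\n",
    "    # Hypothetical helper: a float is a fraction of the loader's batches\n",
    "    # (rounded down), an int is an absolute count capped at the loader length.\n",
    "    if isinstance(limit, float):\n",
    "        return int(num_batches * limit)\n",
    "    return min(limit, num_batches)\n",
    "\n",
    "\n",
    "# 10,000 test images with batch_size=32 -> 313 batches\n",
    "num_test_batches = math.ceil(10000 / 32)\n",
    "print(limited_batches(10, num_test_batches))    # 10\n",
    "print(limited_batches(0.25, num_test_batches))  # 78\n",
    "print(limited_batches(1.0, num_test_batches))   # 313\n",
    "```"
   ]
  },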
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YRI8THtUN7_e"
   },
   "source": [
    "# Training on GPUs\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "R8FFkX_FwlfE"
   },
   "source": [
    "To run on 1 GPU, set the `gpus` flag to 1:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Nnzkf3KaOE27"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gpus=1)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cxBg47s5PB1P"
   },
   "source": [
    "To run on 2 or 4 GPUs, set the flag to 2 or 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cSEM4ihLrohT"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gpus=2)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZE6ZgwtNudro"
   },
   "source": [
    "You can also select which GPU devices to run on, using a list of indices like [1, 4]\n",
    "\n",
    "or a string containing a comma-separated list of GPU IDs, like '1,2'.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gQkJtq0urrjq"
   },
   "outputs": [],
   "source": [
    "# list: train on GPUs 1, 4 (by bus ordering)\n",
    "# trainer = pl.Trainer(gpus='1, 4')  # equivalent\n",
    "trainer = pl.Trainer(gpus=[1, 4])\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XghDPad4us74"
   },
   "outputs": [],
   "source": [
    "# list: train on GPUs 0, 1, 2 and 3\n",
    "trainer = pl.Trainer(gpus=list(range(4)))\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6FVkKHpSPMTW"
   },
   "source": [
    "You can use all the GPUs you have available by setting `gpus=-1`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "r6cKQijYrtPe"
   },
   "outputs": [],
   "source": [
    "# trainer = pl.Trainer(gpus='-1')  # equivalent\n",
    "trainer = pl.Trainer(gpus=-1)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2C-fNLm3UGCV"
   },
   "source": [
    "Lightning uses the PCI bus_id as the index for ordering GPUs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_V75s7EhOFhE"
   },
   "source": [
    "### `auto_select_gpus`\n",
    "\n",
    "GPUs can be set to “exclusive mode”, meaning only one process at a time can access them. If you're not sure which GPUs you should use when running in exclusive mode, Lightning can automatically find unoccupied GPUs for you.\n",
    "\n",
    "Simply specify the number of GPUs as an integer `gpus=k` and set the Trainer flag `auto_select_gpus=True`. Lightning will then find k GPUs that are not occupied by other processes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_Sd3XFsAOIwd"
   },
   "outputs": [],
   "source": [
    "# enable auto selection (will find two available gpus on system)\n",
    "trainer = pl.Trainer(gpus=2, auto_select_gpus=True)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a5JGSBMQhJNp"
   },
   "source": [
    "## Analyzing GPU usage\n",
    "\n",
    "### log_gpu_memory\n",
    "\n",
    "This flag is useful to analyze the memory usage of your GPUs.\n",
    "\n",
    "To get the memory usage of every GPU on the master node, set `log_gpu_memory='all'`.\n",
    "\n",
    "Under the hood, Lightning uses the nvidia-smi command, which may slow your training down.\n",
    "\n",
    "Your logs can become overwhelmed if you log the usage of many GPUs at once. In this case, you can set the flag to `'min_max'`, which logs only the min and max usage across all the GPUs of the master node.\n",
    "\n",
    "Note that, for performance reasons, Lightning does not log usage across all nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "idus3ZGahOki"
   },
   "outputs": [],
   "source": [
    "# log all the GPUs (on master node only)\n",
    "trainer = pl.Trainer(log_gpu_memory='all')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-mevgiy_hkip"
   },
   "source": [
    "To avoid the performance decrease, you can also set `log_gpu_memory='min_max'` to log only the min and max memory on the master node.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SlvLJnWyhs7J"
   },
   "outputs": [],
   "source": [
    "# log only the min and max memory on the master node\n",
    "trainer = pl.Trainer(log_gpu_memory='min_max')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K82FLLIJVQG3"
   },
   "source": [
    "\n",
    "But what if you want to train on multiple machines and not just one?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YViQ6PXesAue"
   },
   "source": [
    "# Training on multiple GPUs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WacbBQUivxQq"
   },
   "source": [
    "Lightning makes your models hardware-agnostic, and you can run on GPUs with a flip of a flag. Lightning also supports training on multiple GPUs across many machines.\n",
    "\n",
    "You can do this by setting the `num_nodes` flag.\n",
    "\n",
    "The world size, or the total number of GPUs you are using, will be `gpus * num_nodes`.\n",
    "\n",
    "If you set `gpus=8` and `num_nodes=32`, you will be training on 256 GPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5iKckmDvr8zZ"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gpus=8, num_nodes=32)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
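  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The world-size arithmetic above, as a short sketch. The second function assumes the `ddp` accelerator, where every process draws its own batch of `batch_size` samples from its DataLoader, so the number of samples contributing to one optimizer step grows with the world size. Both functions are illustrative, not part of Lightning.\n",
    "\n",
    "```python\n",
    "def world_size(gpus, num_nodes):\n",
    "    # DDP launches one process per GPU on every node\n",
    "    return gpus * num_nodes\n",
    "\n",
    "\n",
    "def samples_per_step(batch_size, gpus, num_nodes):\n",
    "    # assumes each process draws its own batch of `batch_size` samples\n",
    "    return batch_size * world_size(gpus, num_nodes)\n",
    "\n",
    "\n",
    "print(world_size(gpus=8, num_nodes=32))                       # 256\n",
    "print(samples_per_step(batch_size=32, gpus=8, num_nodes=32))  # 8192\n",
    "```"
   ]
  },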
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GgcSbDjjlSTh"
   },
   "source": [
    "## Accelerators\n",
    "\n",
    "Under the hood, Lightning uses distributed data parallel (or DDP) by default to distribute training across GPUs.\n",
    "\n",
    "Lightning's implementation of DDP calls your script multiple times under the hood, with the correct environment variables set for each process.\n",
    "\n",
    "In effect, this is what happens:\n",
    "\n",
    "1. Each GPU across each node gets its own process.\n",
    "2. Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset.\n",
    "3. Each process inits the model. (Make sure to set the random seed so that each model initializes with the same weights.)\n",
    "4. Each process performs a full forward and backward pass in parallel.\n",
    "5. The gradients are synced and averaged across all processes.\n",
    "6. Each process updates its optimizer.\n",
    "\n",
    "If you request multiple GPUs or nodes without setting a mode, DDP will be automatically used.\n"
   ]
  },
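  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Steps 2 to 6 can be simulated in plain Python to show why DDP keeps every replica identical. This toy sketch fits a one-parameter linear model `y = w * x` with a mean-squared-error loss; the Python list stands in for separate processes, and the explicit average stands in for the all-reduce that real DDP performs across devices. It is not how PyTorch implements DDP.\n",
    "\n",
    "```python\n",
    "def local_gradient(w, shard):\n",
    "    # d/dw of mean((w * x - y) ** 2) over this process's shard\n",
    "    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)\n",
    "\n",
    "\n",
    "def ddp_step(weights, shards, lr):\n",
    "    # 4. every process runs forward and backward on its own shard\n",
    "    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]\n",
    "    # 5. gradients are averaged across processes (the all-reduce)\n",
    "    avg_grad = sum(grads) / len(grads)\n",
    "    # 6. every process applies the same update to its own replica\n",
    "    return [w - lr * avg_grad for w in weights]\n",
    "\n",
    "\n",
    "data = [(x, 3.0 * x) for x in range(1, 9)]  # toy targets: y = 3x\n",
    "world = 4\n",
    "shards = [data[rank::world] for rank in range(world)]  # 2. one subset per process\n",
    "weights = [0.0] * world  # 3. identical initial weights (same seed)\n",
    "for _ in range(50):\n",
    "    weights = ddp_step(weights, shards, lr=0.01)\n",
    "print(weights)\n",
    "```\n",
    "\n",
    "Because every replica starts from the same weight and applies the same averaged gradient, all four weights stay identical and converge to the true slope of 3.0."
   ]
  },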
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "n_Brr7F5wdtj"
   },
   "outputs": [],
   "source": [
    "# ddp = DistributedDataParallel\n",
    "# trainer = pl.Trainer(gpus=2, num_nodes=2) equivalent\n",
    "trainer = pl.Trainer(gpus=2, num_nodes=2, accelerator='ddp')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "edxHyttC5J3e"
   },
   "source": [
    "DDP is the fastest and recommended way to distribute your training, but you can pass other backends to the `accelerator` Trainer flag when DDP is not supported.\n",
    "\n",
    "DDP isn't available:\n",
    "* in Jupyter Notebook, Google Colab, Kaggle, etc.\n",
    "* if you have a nested script without a root package\n",
    "* if your script needs to invoke .fit or .test multiple times"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZDh96mavxHxf"
   },
   "source": [
    "### DDP_SPAWN\n",
    "\n",
    "In these cases, you can use `ddp_spawn` instead. `ddp_spawn` is exactly like DDP except that it uses `.spawn()` to start the training processes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JM5TKtgLxo37"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gpus=2, num_nodes=2, accelerator='ddp_spawn')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sebhVE3qrhKK"
   },
   "source": [
    "We STRONGLY discourage this use because it has limitations (due to Python and PyTorch):\n",
    "\n",
    "* Since .spawn() trains the model in subprocesses, the model on the main process does not get updated.\n",
    "\n",
    "* DataLoader(num_workers=N), where N is large, bottlenecks training with `ddp_spawn`, i.e., it will be VERY slow or won't work at all. This is a PyTorch limitation.\n",
    "\n",
    "* Forces everything to be picklable.\n",
    "\n",
    "DDP is MUCH faster than `ddp_spawn`. To be able to use DDP, we recommend you:\n",
    "\n",
    "1. Install a top-level module for your project using setup.py\n",
    "\n",
    "```\n",
    "#!/usr/bin/env python\n",
    "# setup.py\n",
    "\n",
    "from setuptools import setup, find_packages\n",
    "\n",
    "setup(name='src',\n",
    "      version='0.0.1',\n",
    "      description='Describe Your Cool Project',\n",
    "      author='',\n",
    "      author_email='',\n",
    "      url='https://github.com/YourSeed',  # REPLACE WITH YOUR OWN GITHUB PROJECT LINK\n",
    "      install_requires=[\n",
    "            'pytorch-lightning'\n",
    "      ],\n",
    "      packages=find_packages()\n",
    "      )\n",
    "\n",
    "```\n",
    "\n",
    "2. Set up your project like so:\n",
    "\n",
    "```\n",
    "/project\n",
    "    /src\n",
    "        some_file.py\n",
    "        /or_a_folder\n",
    "    setup.py\n",
    "```\n",
    "3. Install as a root-level package\n",
    "```\n",
    "cd /project\n",
    "pip install -e .\n",
    "```\n",
    "4. You can then call your scripts anywhere\n",
    "```\n",
    "cd /project/src\n",
    "\n",
    "python some_file.py --accelerator 'ddp' --gpus 8\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cmB3I_oyw7a8"
   },
   "source": [
    "### DP\n",
    "\n",
    "If you're using Windows, DDP is not supported. You can use `dp` for DataParallel instead. DataParallel uses multithreading instead of multiprocessing, and splits each batch across k GPUs. That is, if you have a batch of 32 and use DP with 2 GPUs, each GPU will process 16 samples, after which the root node will aggregate the results.\n",
    "\n",
    "DP use is discouraged by PyTorch and Lightning. Use DDP, which is more stable and at least 3x faster.\n"
   ]
  },
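  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch split described above can be sketched as follows. This assumes DataParallel scatters the batch along its first dimension in `torch.chunk`-style pieces of `ceil(batch_size / k)` samples, with a smaller last piece when the batch does not divide evenly; `dp_chunk_sizes` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def dp_chunk_sizes(batch_size, num_gpus):\n",
    "    # Hypothetical helper: split a batch along dim 0 into torch.chunk-style\n",
    "    # pieces of ceil(batch_size / num_gpus) samples, the last one possibly smaller.\n",
    "    per_gpu = math.ceil(batch_size / num_gpus)\n",
    "    sizes = []\n",
    "    remaining = batch_size\n",
    "    while remaining > 0:\n",
    "        sizes.append(min(per_gpu, remaining))\n",
    "        remaining -= sizes[-1]\n",
    "    return sizes\n",
    "\n",
    "\n",
    "print(dp_chunk_sizes(32, 2))  # [16, 16]\n",
    "print(dp_chunk_sizes(30, 4))  # [8, 8, 8, 6]\n",
    "```"
   ]
  },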
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OO-J0ISvlVCg"
   },
   "outputs": [],
   "source": [
    "# dp = DataParallel\n",
    "trainer = pl.Trainer(gpus=2, accelerator='dp')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y7E2eHZKwUn9"
   },
   "source": [
    "### DDP2\n",
    "\n",
    "In certain cases, it’s advantageous to use ***all*** batches on the same machine, instead of a subset. For instance, in self-supervised learning, a common performance boost comes from increasing the number of negative samples.\n",
    "\n",
    "In this case, we can use DDP2, which behaves like DP within a machine and DDP across nodes. DDP2 does the following:\n",
    "\n",
    "* Copies a subset of the data to each node.\n",
    "* Inits a model on each node.\n",
    "* Runs a forward and backward pass using DP.\n",
    "* Syncs gradients across nodes.\n",
    "* Applies the optimizer updates.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y4xweqL3xHER"
   },
   "outputs": [],
   "source": [
    "# ddp2 = DistributedDataParallel + dp\n",
    "trainer = pl.Trainer(gpus=2, num_nodes=2, accelerator='ddp2')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lhKNCnveeeq5"
   },
   "source": [
    "A final note on `ddp_spawn`: it works like `ddp`, but instead of calling your script multiple times, Lightning uses multiprocessing spawn to start a subprocess per GPU.\n",
    "\n",
    "Be careful when mixing this mode with num_workers > 0 in your DataLoaders, because it will bottleneck your training. This is a known limitation of PyTorch, which is why we recommend using our `ddp` implementation instead.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HUf9ANyQkFFO"
   },
   "source": [
    "\n",
    "### Mocking DDP\n",
    "\n",
    "Testing or debugging DDP can be hard, so Lightning has a distributed backend that simulates DDP on CPUs to make it easier. Set `num_processes` to a number greater than 1 when using `accelerator=\"ddp_cpu\"` to mimic distributed training on a machine without GPUs. Note that while this is useful for debugging, it will not provide any speedup, since single-process Torch already makes efficient use of multiple CPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZSal5Da9kHOf"
   },
   "outputs": [],
   "source": [
    "# Simulate DDP for debugging on your GPU-less laptop\n",
    "trainer = pl.Trainer(accelerator=\"ddp_cpu\", num_processes=2)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Br_btCy5lgES"
   },
   "source": [
    "# Training on TPUs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DXkBNITdv44d"
   },
   "source": [
    "Another option for accelerating your training is using TPUs.\n",
    "A TPU is a Tensor Processing Unit, designed specifically for deep learning. Each TPU has 8 cores, where each core is optimized for 128x128 matrix multiplies. Google estimates that 8 TPU cores are about as fast as 4 V100 GPUs!\n",
    "\n",
    "A TPU pod hosts many TPUs. Currently, a TPU v2 pod has 2048 cores! You can request a full pod from Google Cloud or a “slice”, which gives you some subset of those 2048 cores.\n",
    "\n",
    "At the moment, TPUs are available on Google Cloud (GCP), Google Colab and Kaggle environments.\n",
    "\n",
    "Lightning supports training on TPUs without any code adjustments to your model. Just like when using GPUs, Lightning automatically inserts the correct samplers - no need to do this yourself!\n",
    "\n",
    "Under the hood, Lightning uses the XLA framework developed jointly by the Facebook and Google XLA teams. We want to recognize their efforts in advancing TPU adoption of PyTorch.\n",
    "\n",
    "## tpu_cores\n",
    "To train on TPUs, set the `tpu_cores` flag.\n",
    "\n",
    "When using Colab or Kaggle, the allowed values are 1 or 8 cores. When using Google Cloud, any value above 8 is allowed.\n",
    "\n",
    "Your effective batch size is the batch size passed into a DataLoader times the total number of TPU cores."
   ]
  },
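  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick worked example of that rule, with illustrative numbers: a per-core batch size of 64 on 8 TPU cores means every training step processes 64 × 8 = 512 samples. `effective_batch_size` is a hypothetical helper used only for this sketch.\n",
    "\n",
    "```python\n",
    "def effective_batch_size(per_core_batch_size, tpu_cores):\n",
    "    # every core receives its own batch of per_core_batch_size samples,\n",
    "    # so one training step processes this many samples in total\n",
    "    return per_core_batch_size * tpu_cores\n",
    "\n",
    "print(effective_batch_size(64, 8))  # 512 with tpu_cores=8\n",
    "print(effective_batch_size(64, 1))  # 64 with tpu_cores=1\n",
    "```"
   ]
  },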
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "itP9y70gmD9M"
   },
   "outputs": [],
   "source": [
    "# int: train on a single core\n",
    "trainer = pl.Trainer(tpu_cores=1)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NJKnzPb3mKEg"
   },
   "outputs": [],
   "source": [
    "# int: train on all 8 cores\n",
    "trainer = pl.Trainer(tpu_cores=8)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8a4exfWUmOHq"
   },
   "source": [
    "You can also choose which TPU core to train on by passing a list with a single core index between 1 and 8. This is not an officially supported use case, but we are working with the XLA team to improve this user experience.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "S6OrjE_bmT-_"
   },
   "outputs": [],
   "source": [
    "# list: train on a single selected core\n",
    "trainer = pl.Trainer(tpu_cores=[2])\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Afqx3sFUmfWD"
   },
   "source": [
    "To train on more than 8 cores (i.e. a pod), launch your script with the `xla_dist` script:\n",
    "\n",
    "```\n",
    "python -m torch_xla.distributed.xla_dist \\\n",
    "  --tpu=$TPU_POD_NAME \\\n",
    "  --conda-env=torch-xla-nightly \\\n",
    "  --env=XLA_USE_BF16=1 \\\n",
    "  -- python your_trainer_file.py\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ncPvbUVQqKOh"
   },
   "source": [
    "# Advanced distributed training\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4MP7bEgnv7qK"
   },
   "source": [
    "\n",
    "Lightning supports distributed training across multiple GPUs and TPUs out of the box by setting trainer flags, but it also allows you to control the way sampling is done if you need to."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wdHiTfAMepKH"
   },
   "source": [
    "## replace_sampler_ddp\n",
    "In PyTorch, you must use `torch.utils.data.distributed.DistributedSampler` for multi-node or multi-GPU training. The sampler makes sure each GPU sees the appropriate part of your data.\n",
    "\n",
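    "```\n",
    "# without lightning\n",
    "import torch.distributed as dist\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "\n",
    "def train_dataloader(self):\n",
    "    dataset = MNIST(...)\n",
    "    sampler = None\n",
    "\n",
    "    # shard the data only when a distributed process group is running\n",
    "    if dist.is_available() and dist.is_initialized():\n",
    "        sampler = DistributedSampler(dataset)\n",
    "\n",
    "    return DataLoader(dataset, sampler=sampler)\n",
    "```\n",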
    "Lightning adds the correct samplers when needed, so there is no need to add them explicitly. By default it uses `shuffle=True` for the train sampler and `shuffle=False` for the val/test samplers.\n",
    "\n",
    "If you want to customize this behaviour, set `replace_sampler_ddp=False` and add your own distributed sampler.\n",
    "\n",
    "Note: for iterable datasets, Lightning does not add samplers automatically.\n"
   ]
  },
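  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of each GPU seeing its own part of the data concrete, here is a simplified pure-Python sketch of how `DistributedSampler` splits dataset indices when `drop_last=False`: the index list is padded by repeating its first elements until it divides evenly, then each rank takes every `num_replicas`-th index starting at its own rank. Shuffling is left out, and `shard_indices` is a hypothetical helper, not a PyTorch function.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def shard_indices(dataset_len, num_replicas, rank):\n",
    "    # pad so that every replica gets the same number of samples\n",
    "    num_samples = math.ceil(dataset_len / num_replicas)\n",
    "    total_size = num_samples * num_replicas\n",
    "    indices = list(range(dataset_len))\n",
    "    indices += indices[:total_size - len(indices)]\n",
    "    # strided split: rank r gets indices r, r + num_replicas, ...\n",
    "    return indices[rank:total_size:num_replicas]\n",
    "\n",
    "# 10 samples spread over 4 processes (e.g. 2 nodes x 2 GPUs)\n",
    "for rank in range(4):\n",
    "    print(rank, shard_indices(10, 4, rank))\n",
    "# rank 2 gets [2, 6, 0] and rank 3 gets [3, 7, 1]: indices 0 and 1 are reused as padding\n",
    "```"
   ]
  },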
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZfmcB_e_7HbE"
   },
   "outputs": [],
   "source": [
    "sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)\n",
    "dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)\n",
    "\n",
    "trainer = pl.Trainer(gpus=2, num_nodes=2, replace_sampler_ddp=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-IOhk1n0lL3_"
   },
   "source": [
    "## prepare_data_per_node\n",
    "\n",
    "When doing multi-node training, if your nodes share the same file system, you don't want to download the data more than once, to avoid possible collisions.\n",
    "\n",
    "With the default `prepare_data_per_node=True`, Lightning calls the `prepare_data` hook on the root GPU (LOCAL_RANK=0) of each node. This is what you need when your nodes don't share a file system and every node needs its own copy of the data.\n",
    "\n",
    "If your nodes do share a file system, set this flag to False and Lightning will call `prepare_data` only on the root GPU of the master node (NODE_RANK=0, LOCAL_RANK=0), i.e. on a single GPU."
   ]
  },
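  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rule can be written as a small predicate over the node rank and the local (per-node) rank. `calls_prepare_data` is a hypothetical helper that restates the behaviour described above; it is not Lightning's internal code.\n",
    "\n",
    "```python\n",
    "def calls_prepare_data(node_rank, local_rank, prepare_data_per_node):\n",
    "    # only the root process (local_rank 0) of a node ever prepares data\n",
    "    if local_rank != 0:\n",
    "        return False\n",
    "    # per node: the root of every node; otherwise only the master node's root\n",
    "    return prepare_data_per_node or node_rank == 0\n",
    "\n",
    "# 2 nodes with 2 GPUs each\n",
    "for per_node in (True, False):\n",
    "    callers = [(node, local) for node in range(2) for local in range(2)\n",
    "               if calls_prepare_data(node, local, per_node)]\n",
    "    print(per_node, callers)\n",
    "# True [(0, 0), (1, 0)]\n",
    "# False [(0, 0)]\n",
    "```"
   ]
  },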
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WFBMUR48lM04"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gpus=2, num_nodes=2, prepare_data_per_node=False)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FKBwXqo4q-Vp"
   },
   "source": [
    "## sync_batchnorm\n",
    "\n",
    "By default, batch norm statistics are computed separately on each GPU/TPU. This flag synchronizes the batch norm statistics across all GPUs, so that they are computed over the whole global batch.\n",
    "It is recommended when the batch size per GPU is small.\n"
   ]
  },
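  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why small batches matter: without synchronization each GPU normalizes with the mean and variance of only its own few samples, while synchronized batch norm uses the statistics of the whole global batch. The pure-Python sketch below computes both for one feature with made-up values split across 4 GPUs; the real layer does this per channel and exchanges the statistics between processes.\n",
    "\n",
    "```python\n",
    "def mean_var(xs):\n",
    "    # biased variance, as used for batch norm's normalization\n",
    "    m = sum(xs) / len(xs)\n",
    "    return m, sum((x - m) ** 2 for x in xs) / len(xs)\n",
    "\n",
    "# one feature, 2 samples per GPU on 4 GPUs (illustrative values)\n",
    "per_gpu = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]\n",
    "\n",
    "# without sync: each GPU uses its own statistics\n",
    "local_stats = [mean_var(batch) for batch in per_gpu]\n",
    "\n",
    "# with sync: statistics over the whole global batch of 8 samples\n",
    "global_stats = mean_var([x for batch in per_gpu for x in batch])\n",
    "\n",
    "print(local_stats)   # means 1.5, 3.5, 5.5, 7.5, each with variance 0.25\n",
    "print(global_stats)  # (4.5, 5.25)\n",
    "```"
   ]
  },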
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GhaCLTEZrAQi"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gpus=4, sync_batchnorm=True)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XuFA7VTFMY9-"
   },
   "source": [
    "# Debugging flags\n",
    "\n",
    "Lightning offers a couple of flags to make debugging your models easier:\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AKoS3fdml4Jx"
   },
   "source": [
    "## Fast Dev Run\n",
    "\n",
    "To help you save time debugging, your first run should use the `fast_dev_run` flag.\n",
    "\n",
    "It runs only a single batch through your training, validation and test loops. It won't generate logs or save checkpoints, but it touches every line of your code to make sure that it works as intended.\n",
    "\n",
    "Think of this flag like a compiler: you make changes to your code, then run the Trainer with this flag to verify that your changes are bug-free.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L5vuG7GSmhzK"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(fast_dev_run=True)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HRP1qQR5nT4p"
   },
   "source": [
    "## overfit_batches\n",
    "\n",
    "Uses only this much of the training set: a float is a fraction of the training batches, an int is a number of batches. If nonzero, the same training data is also used for validation and testing. If the training dataloaders have `shuffle=True`, Lightning automatically disables shuffling.\n",
    "\n",
    "Useful for quickly debugging or trying to overfit on purpose."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NTM-dqGMnXms"
   },
   "outputs": [],
   "source": [
    "# use only 1% of the train set (and use the train set for val and test)\n",
    "trainer = pl.Trainer(overfit_batches=0.01)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "c0LV0gC3nl1X"
   },
   "outputs": [],
   "source": [
    "# overfit on 10 of the same batches\n",
    "trainer = pl.Trainer(overfit_batches=10)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lt3UHU6WgtS_"
   },
   "source": [
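    "The `limit_train_batches`, `limit_val_batches` and `limit_test_batches` flags work similarly: pass an int to run a fixed number of batches, or a float to run a fraction of the data."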
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K3yUqADhgnkf"
   },
   "outputs": [],
   "source": [
    "# run through only 25% of the validation set each epoch\n",
    "# (limit_test_batches works the same way for trainer.test())\n",
    "trainer = pl.Trainer(limit_val_batches=0.25)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ODN66NeVg_2o"
   },
   "source": [
    "With multiple validation or test dataloaders, the limit applies to each dataloader individually.\n"
   ]
  },
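  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of how an int or float limit translates into a number of batches, applied to each dataloader separately. `limited_batches` is a hypothetical helper that restates the rule (float: fraction of the batches, int: number of batches); it is not Lightning's code.\n",
    "\n",
    "```python\n",
    "def limited_batches(num_batches, limit):\n",
    "    # float: run this fraction of the dataloader's batches\n",
    "    if isinstance(limit, float):\n",
    "        return int(num_batches * limit)\n",
    "    # int: run at most this many batches\n",
    "    return min(num_batches, limit)\n",
    "\n",
    "# two test dataloaders with 200 and 40 batches\n",
    "loader_sizes = [200, 40]\n",
    "print([limited_batches(n, 0.25) for n in loader_sizes])  # [50, 10]\n",
    "print([limited_batches(n, 100) for n in loader_sizes])   # [100, 40]\n",
    "```"
   ]
  },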
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8aQx5SLeMz1R"
   },
   "source": [
    "# accumulate_grad_batches\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "g8GczZXFwKC7"
   },
   "source": [
    "The batch size controls the accuracy of the gradient estimate: smaller batches use less memory but give noisier gradients. When training large models, such as NLP transformers, it is useful to accumulate gradients over several batches before calling `optimizer.step()`. This allows for bigger effective batch sizes than what can actually fit on a GPU/TPU in a single step.\n",
    "\n",
    "Use `accumulate_grad_batches` to accumulate gradients every k batches, or according to a dict schedule. The Trainer also calls `optimizer.step()` for the last batches of an epoch when the number of batches is not divisible by k.\n",
    "\n",
    "For example, set `accumulate_grad_batches` to 4 to accumulate gradients over 4 batches. In this case the effective batch size is batch_size*4, so if your batch size is 32, it will effectively be 128."
   ]
  },
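  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why accumulation gives a larger effective batch: when the loss is a mean over samples, averaging the gradients of k equally sized batches yields exactly the gradient of one batch k times larger. The pure-Python sketch below checks this for a one-parameter least-squares model with made-up data; in PyTorch the summing happens because each `backward()` call adds into `.grad` until the optimizer steps and the gradients are zeroed.\n",
    "\n",
    "```python\n",
    "def grad(w, xs, ys):\n",
    "    # derivative of mean((w * x - y) ** 2) with respect to w\n",
    "    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)\n",
    "\n",
    "xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]\n",
    "ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 16.2]\n",
    "w = 0.5\n",
    "\n",
    "# one batch of 8 samples\n",
    "full_batch_grad = grad(w, xs, ys)\n",
    "\n",
    "# accumulate_grad_batches=4: four batches of 2 samples each\n",
    "accumulated = 0.0\n",
    "for i in range(0, 8, 2):\n",
    "    accumulated += grad(w, xs[i:i + 2], ys[i:i + 2]) / 4  # average over the 4 batches\n",
    "\n",
    "print(full_batch_grad, accumulated)\n",
    "assert abs(full_batch_grad - accumulated) < 1e-9\n",
    "```"
   ]
  },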
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2jB6-Z_yPhhf"
   },
   "outputs": [],
   "source": [
    "# accumulate every 4 batches (effective batch size is batch*4)\n",
    "trainer = pl.Trainer(accumulate_grad_batches=4)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_Yi-bdTOgINC"
   },
   "source": [
    "You can also pass a dictionary that maps an epoch to the accumulation factor used from that epoch on. For example, `{5: 3, 10: 20}` means no accumulation before epoch 5, accumulating 3 batches for epochs 5 to 9, and 20 batches from epoch 10 on."
   ]
  },
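  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dictionary reads as: from this epoch on, use this factor. `accumulation_factor` is a hypothetical helper that resolves the factor for a given epoch under that reading (epochs before the smallest key use a factor of 1, i.e. no accumulation):\n",
    "\n",
    "```python\n",
    "def accumulation_factor(schedule, epoch):\n",
    "    # start without accumulation, then apply every threshold already reached\n",
    "    factor = 1\n",
    "    for start_epoch in sorted(schedule):\n",
    "        if epoch >= start_epoch:\n",
    "            factor = schedule[start_epoch]\n",
    "    return factor\n",
    "\n",
    "schedule = {5: 3, 10: 20}\n",
    "print([accumulation_factor(schedule, epoch) for epoch in range(12)])\n",
    "# [1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 20, 20]\n",
    "```"
   ]
  },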
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "X3xsoZ3YPgBv"
   },
   "outputs": [],
   "source": [
    "# no accumulation before epoch 5, accumulate 3 for epochs 5-9, accumulate 20 from epoch 10 on\n",
    "trainer = pl.Trainer(accumulate_grad_batches={5: 3, 10: 20})\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "myzH8mV4M1_9"
   },
   "source": [
    "# 16 bit precision\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v9EaFAonwOk6"
   },
   "source": [
    "Most deep learning frameworks, like PyTorch, train with 32-bit floating point arithmetic.\n",
    "\n",
    "But many models can still achieve full accuracy using half the precision.\n",
    "\n",
    "In 2017, NVIDIA researchers successfully used a combination of 32-bit and 16-bit precision (also known as mixed precision) and achieved the same accuracy as 32-bit precision training.\n",
    "\n",
    "The two main advantages are:\n",
    "\n",
    "- a reduction in memory requirements, which enables larger batch sizes and models.\n",
    "- a speed-up in compute. On Ampere, Turing and Volta architectures, 16-bit precision models can train up to 3 times faster.\n",
    "\n",
    "As of PyTorch 1.6, NVIDIA and Facebook moved mixed precision functionality into PyTorch core as the AMP package, `torch.cuda.amp`.\n",
    "\n",
    "This package supersedes the Apex package developed by NVIDIA."
   ]
  },
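  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Python's `struct` module can pack IEEE 754 half-precision floats (format `'e'`), which is enough to see both sides of this trade-off without a GPU: a 16-bit value takes half the bytes of a 32-bit one, but it keeps only about 3 significant decimal digits, and very small values such as tiny gradients round to zero. The last lines only illustrate why mixed precision training scales the loss (the 16-bit scaling factor that Lightning stores in checkpoints); the factor 1024 is an arbitrary choice for this sketch.\n",
    "\n",
    "```python\n",
    "import struct\n",
    "\n",
    "def to_fp16(x):\n",
    "    # round-trip a Python float through IEEE 754 half precision\n",
    "    return struct.unpack('<e', struct.pack('<e', x))[0]\n",
    "\n",
    "print(struct.calcsize('<e'), struct.calcsize('<f'))  # 2 4 (bytes per value)\n",
    "print(to_fp16(0.1))   # 0.0999755859375\n",
    "print(to_fp16(1e-8))  # 0.0, too small for half precision\n",
    "\n",
    "# scaling a tiny value up before the cast keeps it representable;\n",
    "# dividing afterwards in higher precision recovers an approximation\n",
    "scale = 1024.0\n",
    "print(to_fp16(1e-8 * scale) / scale)  # close to 1e-8 instead of 0.0\n",
    "```"
   ]
  },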
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TjNypZPHnxvJ"
   },
   "source": [
    "## precision\n",
    "\n",
    "Use the `precision` flag to switch from full precision (32) to half precision (16). It can be used on CPU, GPU or TPUs.\n",
    "\n",
    "With PyTorch 1.6+, Lightning uses the native AMP implementation to support 16-bit precision.\n",
    "\n",
    "On TPUs, 16-bit precision uses `torch.bfloat16`, but printed tensors will still show `torch.float32`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kBZKMVx1nw-D"
   },
   "outputs": [],
   "source": [
    "# 16-bit precision\n",
    "trainer = pl.Trainer(gpus=1, precision=16)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VJGj3Jh7oQXU"
   },
   "source": [
    "Earlier versions of Lightning used NVIDIA Apex for 16-bit precision. Apex was the first library to attempt 16-bit training, and its automatic mixed precision library (amp) has since been merged into core PyTorch as of 1.6.\n",
    "\n",
    "If you still want to use Apex, set the `amp_backend` flag to `'apex'` and install Apex yourself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BDV1trAUPc9h"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gpus=1, precision=16, amp_backend='apex')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HK5c_aVfNV4e"
   },
   "source": [
    "## amp_level\n",
    "When using the Apex backend, `amp_level` selects one of Apex's 4 optimization levels:\n",
    "\n",
    "- O0 (FP32 training)\n",
    "- O1 (Conservative Mixed Precision): only some whitelisted ops are done in FP16.\n",
    "- O2 (Fast Mixed Precision): the standard mixed precision training. It maintains FP32 master weights, and `optimizer.step()` acts directly on the FP32 master weights.\n",
    "- O3 (FP16 training): full FP16. Passing `keep_batchnorm_fp32=True` can speed things up, as cuDNN batchnorm is faster anyway.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FshMFPowNbWt"
   },
   "outputs": [],
   "source": [
    "# default used by the Trainer\n",
    "trainer = pl.Trainer(gpus=1, precision=16, amp_backend='apex', amp_level='O2')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y8KEr1YvNgkC"
   },
   "source": [
    "# `auto_scale_batch_size`\n",
    "\n",
    " \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7F1pKFIuwSFl"
   },
   "source": [
    "Lightning can help you improve your model with the `auto_scale_batch_size` flag, which tries to find the largest batch size that fits into memory before you start training.\n",
    "A larger batch size often yields better gradient estimates, but may also result in longer training time.\n",
    "\n",
    "Set it to True to run the batch size finder when you call `trainer.tune()`. The result is stored in `self.batch_size` in the LightningModule.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9_jE-iyyheIv"
   },
   "outputs": [],
   "source": [
    "# the batch size finder needs the LightningModule to build its own train_dataloader()\n",
    "trainer = pl.Trainer(auto_scale_batch_size=True)\n",
    "\n",
    "trainer.tune(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yaHsJvwFhNJt"
   },
   "source": [
    "You can set the value to `power`. `power` scaling starts from a batch size of 1 and keeps doubling the batch size until an out-of-memory (OOM) error is encountered.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qx0FbQrphgw1"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(auto_scale_batch_size='power')\n",
    "\n",
    "trainer.tune(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8bwgVF9zhZ75"
   },
   "source": [
    "You can also set it to `binsearch`, which first doubles the batch size until an OOM error occurs and then refines it with a binary search between the last batch size that fit and the one that failed.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QObXNs3yNrg9"
   },
   "outputs": [],
   "source": [
    "# run batch size scaling, result overrides hparams.batch_size\n",
    "trainer = pl.Trainer(auto_scale_batch_size='binsearch')\n",
    "\n",
    "trainer.tune(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5OWdhSsZjqW7"
   },
   "source": [
    "This feature expects a `batch_size` field on your model, i.e. `model.batch_size` or `model.hparams.batch_size`, which will be overridden by the result of this algorithm.\n",
    "\n",
    "Additionally, your `train_dataloader()` method should depend on this field for this feature to work.\n",
    "\n",
    "In short, the algorithm works as follows:\n",
    "1. Dump the current state of the model and trainer.\n",
    "2. Iterate until convergence or until the maximum number of tries `max_trials` (default 25) has been reached:\n",
    "   * Call the `fit()` method of the trainer. This evaluates `steps_per_trial` (default 3) training steps. Each training step can trigger an OOM error if the tensors allocated during the steps (training batch, weights, gradients, etc.) have too large a memory footprint.\n",
    "      * If an OOM error is encountered, decrease the batch size.\n",
    "      * Otherwise, increase it.\n",
    "   * How much the batch size is increased or decreased is determined by the chosen strategy.\n",
    "3. Save the batch size that was found to the `batch_size` field.\n",
    "4. Restore the initial state of the model and trainer.\n"
   ]
  },
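  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A simplified simulation of the two strategies, with a hypothetical `fits` function standing in for a real training step that may raise an OOM error. It follows the description above (double until the first failure; with `binsearch`, then bisect between the last size that fit and the first one that failed) but is not Lightning's implementation, and it assumes that the initial batch size fits.\n",
    "\n",
    "```python\n",
    "def find_batch_size(fits, mode='power', init_val=1, max_trials=25):\n",
    "    size = init_val\n",
    "    last_ok, first_fail = None, None\n",
    "    for _ in range(max_trials):\n",
    "        if fits(size):\n",
    "            last_ok = size\n",
    "        else:\n",
    "            first_fail = size\n",
    "            if mode == 'power':\n",
    "                break  # power scaling stops at the first OOM\n",
    "        if first_fail is None:\n",
    "            size *= 2  # no OOM yet: keep doubling\n",
    "        elif first_fail - last_ok <= 1:\n",
    "            break  # binary search has converged\n",
    "        else:\n",
    "            size = (last_ok + first_fail) // 2\n",
    "    return last_ok\n",
    "\n",
    "def fits(batch_size):\n",
    "    # pretend that anything up to 100 samples fits into memory\n",
    "    return batch_size <= 100\n",
    "\n",
    "print(find_batch_size(fits, 'power'))      # 64\n",
    "print(find_batch_size(fits, 'binsearch'))  # 100\n",
    "```"
   ]
  },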
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "q4CvxfZmOWBd"
   },
   "source": [
    "# `auto_lr_find`\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "j85e8usNwdBV"
   },
   "source": [
    "Selecting a good learning rate for your deep learning training is essential for both better performance and faster convergence.\n",
    "\n",
    "Even adaptive optimizers such as Adam, which adjust the step size of each parameter, can benefit from a better choice of initial learning rate.\n",
    "\n",
    "To reduce the guesswork in choosing a good initial learning rate, you can use Lightning's learning rate finder.\n",
    "\n",
    "The learning rate finder does a small run where the learning rate is increased after each processed batch and the corresponding loss is logged. The result is a learning rate vs. loss plot that can be used as guidance for choosing an optimal initial learning rate.\n",
    "\n",
    "**Warning:** for the moment, this feature only works with models that have a single optimizer. LR finder support for DDP is not implemented yet.\n",
    "\n",
    "***auto_lr_find=***\n",
    "\n",
    "In the most basic use case, this feature can be enabled during trainer construction with `Trainer(auto_lr_find=True)`.\n",
    "When `trainer.tune(model)` is called, the LR finder automatically runs before any training is done. The learning rate that is found and used is written to the console and logged together with all other hyperparameters of the model."
   ]
  },
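  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The procedure can be sketched on a toy problem: the learning rate grows exponentially from `min_lr` to `max_lr`, one value per step, the loss is recorded at each step, and the suggestion is the learning rate at which the loss falls most steeply. The quadratic loss, the plain SGD update and the simple finite-difference suggestion are assumptions made for this sketch; Lightning's finder runs your real training steps and differs in the details.\n",
    "\n",
    "```python\n",
    "def lr_range_test(min_lr=1e-4, max_lr=10.0, num_steps=50):\n",
    "    # toy model: minimize loss(w) = w ** 2 with plain SGD, starting at w = 1\n",
    "    w = 1.0\n",
    "    lrs, losses = [], []\n",
    "    for i in range(num_steps):\n",
    "        # exponential schedule from min_lr to max_lr\n",
    "        lr = min_lr * (max_lr / min_lr) ** (i / (num_steps - 1))\n",
    "        lrs.append(lr)\n",
    "        losses.append(w ** 2)\n",
    "        w -= lr * 2 * w  # the gradient of w ** 2 is 2 * w\n",
    "    return lrs, losses\n",
    "\n",
    "def suggest_lr(lrs, losses):\n",
    "    # learning rate at the steepest drop of the loss (most negative difference)\n",
    "    diffs = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]\n",
    "    steepest = min(range(len(diffs)), key=diffs.__getitem__)\n",
    "    return lrs[steepest]\n",
    "\n",
    "lrs, losses = lr_range_test()\n",
    "print(f'suggested lr: {suggest_lr(lrs, losses):.4f}')\n",
    "```\n",
    "\n",
    "Past a certain learning rate the toy loss starts to grow again, which is the diverging tail you would also see on the right side of a real learning rate finder plot."
   ]
  },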
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iuhve9RBOfFh"
   },
   "outputs": [],
   "source": [
    "# default used by the Trainer (no learning rate finder)\n",
    "trainer = pl.Trainer(auto_lr_find=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BL-gjXNCPDXk"
   },
   "source": [
    "With this flag, the learning rate that is found is written to `self.lr` or `self.learning_rate` in your LightningModule, so `configure_optimizers()` should read it from there.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wEb-vIMmPJQf"
   },
   "outputs": [],
   "source": [
    "class LitModel(pl.LightningModule):\n",
    "\n",
    "    def __init__(self, learning_rate):\n",
    "        super().__init__()\n",
    "        # the learning rate finder overwrites this attribute\n",
    "        self.learning_rate = learning_rate\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
    "\n",
    "# finds the learning rate automatically\n",
    "# and sets self.lr or self.learning_rate to that learning rate\n",
    "trainer = pl.Trainer(auto_lr_find=True)\n",
    "\n",
    "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RweqvpnVPPSh"
   },
   "source": [
    "To have the finder write to a differently named attribute, pass that name as a string to `auto_lr_find`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4LKI39IfPLJv"
   },
   "outputs": [],
   "source": [
    "# the learning rate that is found is written to self.my_value\n",
    "trainer = pl.Trainer(auto_lr_find='my_value')\n",
    "\n",
    "trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9VAhPRKbPX-m"
   },
   "source": [
    "Under the hood, when you call tune it runs the learning rate finder.\n",
    "\n",
    "If you want to inspect the results of the learning rate finder before doing any actual training, or just play around with the parameters of the algorithm, you can invoke the `lr_find` method of the trainer. A typical example looks like this:\n",
    "\n",
    "\n",
    "```\n",
    "trainer = pl.Trainer(auto_lr_find=True)\n",
    "\n",
    "# Run learning rate finder\n",
    "lr_finder = trainer.lr_find(model)\n",
    "\n",
    "# Results can be found in\n",
    "lr_finder.results\n",
    "\n",
    "# Plot with\n",
    "fig = lr_finder.plot(suggest=True)\n",
    "fig.show()\n",
    "\n",
    "# Pick point based on plot, or get suggestion\n",
    "new_lr = lr_finder.suggestion()\n",
    "\n",
    "# update hparams of the model\n",
    "model.hparams.lr = new_lr\n",
    "\n",
    "# Fit model\n",
    "trainer.fit(model)\n",
    "```\n",
    "\n",
    "The figure produced by `lr_finder.plot()` shows the loss against the learning rate. It is recommended not to pick the learning rate that achieves the lowest loss, but instead something in the middle of the sharpest downward slope (the red point). This is the point returned by `lr_finder.suggestion()`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tn1RV-jfOjt1"
   },
   "source": [
    "# `benchmark`\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rsmTl5zfwjM3"
   },
   "source": [
    "You can try to speed up your system by setting `benchmark=True`, which enables `cudnn.benchmark`. This makes the cuDNN auto-tuner look for the optimal set of algorithms for the given hardware configuration, which usually leads to faster runtime when your input sizes don’t change.\n",
    "But if your input sizes change at each iteration, cuDNN will benchmark every time a new size appears, possibly leading to worse runtime performance."
   ]
  },
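  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy model of why input sizes matter: the auto-tuner benchmarks candidate algorithms the first time it sees a new input size and reuses its choice afterwards. `count_benchmarks` is hypothetical and only counts how often that benchmarking would run; no real algorithms are timed.\n",
    "\n",
    "```python\n",
    "def count_benchmarks(input_shapes):\n",
    "    # the choice is cached per input shape\n",
    "    tuned = set()\n",
    "    for shape in input_shapes:\n",
    "        if shape not in tuned:\n",
    "            tuned.add(shape)  # new shape: benchmark all candidate algorithms\n",
    "    return len(tuned)\n",
    "\n",
    "fixed = [(32, 1, 28, 28)] * 100                       # e.g. MNIST batches\n",
    "varying = [(32, 1, 28 + i, 28) for i in range(100)]  # a new size every step\n",
    "print(count_benchmarks(fixed), count_benchmarks(varying))  # 1 100\n",
    "```"
   ]
  },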
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dWr-OCBgQCeb"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gpus=1, benchmark=True)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qwAvSKYGa24K"
   },
   "source": [
    "# `deterministic`\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tl5mfmafwmat"
   },
   "source": [
    "PyTorch does not guarantee reproducible results, even when using identical seeds. To make your results reproducible, you can remove most of the randomness from your process by setting the `deterministic` flag to True.\n",
    "\n",
    "Note that this might make training slower."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Mhv5LZ3HbNCK"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gpus=1, deterministic=True)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "u_5eJSvTf60f"
   },
   "source": [
    "# Exploding and vanishing gradients"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B6drjh4pq6Jv"
   },
   "source": [
    "## track_grad_norm\n",
    "\n",
    "You can debug your grad norm to identify exploding or vanishing gradients using the `track_grad_norm` flag.\n",
    "\n",
    "Set it to 2 to track the 2-norm, or to any p to track the p-norm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2taHUir8rflR"
   },
   "outputs": [],
   "source": [
    "# track the 2-norm\n",
    "trainer = pl.Trainer(track_grad_norm=2)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3vHKxmruk62f"
   },
   "source": [
    "It can also be set to `'inf'` to track the infinity norm."
   ]
  },
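  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, this is how a p-norm and the infinity norm of a set of gradient values are computed. Plain Python lists stand in for gradient tensors, and `grad_norm` is a hypothetical helper. A norm that keeps growing during training hints at exploding gradients, one that shrinks towards zero at vanishing gradients.\n",
    "\n",
    "```python\n",
    "def grad_norm(grads, p):\n",
    "    # 'inf': largest absolute value; otherwise (sum of |g| ** p) ** (1 / p)\n",
    "    if p == 'inf':\n",
    "        return max(abs(g) for g in grads)\n",
    "    return sum(abs(g) ** p for g in grads) ** (1.0 / p)\n",
    "\n",
    "grads = [3.0, -4.0]\n",
    "print(grad_norm(grads, 2))      # 5.0\n",
    "print(grad_norm(grads, 1))      # 7.0\n",
    "print(grad_norm(grads, 'inf'))  # 4.0\n",
    "```"
   ]
  },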
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "g7TbD6SxlAjP"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(track_grad_norm='inf')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TcMlRe7ywpe6"
   },
   "source": [
    "## Gradient clipping\n",
    "\n",
    "\n",
    "Exploding gradients refer to the problem of gradients becoming so large that they overflow during training, making the model unstable. Gradient clipping caps the gradients at a threshold value to prevent them from getting too large. To enable it, set `gradient_clip_val` (the default, 0.0, disables clipping); Lightning then clips the total norm of the gradients to this value."
   ]
  },
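  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of clipping by the total gradient norm, mirroring the rule used by `torch.nn.utils.clip_grad_norm_`: when `max_norm / (total_norm + 1e-6)` is below 1, every gradient is multiplied by that factor, so the direction is kept and only the length shrinks. Plain lists stand in for gradient tensors, and `clip_by_total_norm` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def clip_by_total_norm(grads, max_norm, eps=1e-6):\n",
    "    # grads: one list of gradient values per parameter\n",
    "    total_norm = math.sqrt(sum(g * g for param in grads for g in param))\n",
    "    clip_coef = max_norm / (total_norm + eps)\n",
    "    if clip_coef < 1:\n",
    "        # rescale all gradients by the same factor\n",
    "        grads = [[g * clip_coef for g in param] for param in grads]\n",
    "    return grads, total_norm\n",
    "\n",
    "clipped, norm_before = clip_by_total_norm([[3.0], [4.0]], max_norm=0.1)\n",
    "print(norm_before)  # 5.0\n",
    "print(clipped)      # roughly [[0.06], [0.08]], whose norm is about 0.1\n",
    "```"
   ]
  },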
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jF9JwmbOgOWF"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(gradient_clip_val=0.1)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ggb4MkkQrr1h"
   },
   "source": [
    "# truncated_bptt_steps\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "s1Iu6PyAw9_r"
   },
   "source": [
    "If you have a large recurrent model, you can use the `truncated_bptt_steps` flag to split up the backprop over portions of the sequence. This flag automatically truncates your batches, and the trainer applies truncated backprop to them.\n",
    "\n",
    "Make sure your batches have a sequence dimension.\n",
    "\n",
    "Lightning takes care of splitting your batch along the time dimension:\n",
    "```\n",
    "# we use the second dimension as the time dimension\n",
    "# (batch, time, ...)\n",
    "sub_batch = batch[:, 0:t, ...]\n",
    "```\n",
    "Using this feature requires updating your LightningModule’s `training_step()` to include a `hiddens` argument with the hidden states:\n",
    "```\n",
    "# Truncated back-propagation through time\n",
    "def training_step(self, batch, batch_idx, hiddens):\n",
    "    x, y = batch\n",
    "    # hiddens are the hidden states from the previous truncated backprop step\n",
    "    out, hiddens = self.lstm(x, hiddens)\n",
    "    loss = ...  # compute your loss from out and y\n",
    "\n",
    "    return {\n",
    "        \"loss\": loss,\n",
    "        \"hiddens\": hiddens  # remember to detach() this\n",
    "    }\n",
    "```"
   ]
  },
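  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `truncated_bptt_steps=5`, a batch laid out as (batch, time, ...) is cut into chunks of at most 5 time steps; backprop runs on each chunk separately while the hidden state is carried from one chunk to the next. A pure-Python sketch of the splitting step, with nested lists standing in for tensors (`split_time` is a hypothetical helper):\n",
    "\n",
    "```python\n",
    "def split_time(batch, steps):\n",
    "    # batch[b][t] is sample b at time step t; slice every sample along time\n",
    "    seq_len = len(batch[0])\n",
    "    return [[sample[t:t + steps] for sample in batch]\n",
    "            for t in range(0, seq_len, steps)]\n",
    "\n",
    "# 2 sequences of 12 time steps each\n",
    "batch = [list(range(12)), list(range(100, 112))]\n",
    "chunks = split_time(batch, 5)\n",
    "print([len(chunk[0]) for chunk in chunks])  # [5, 5, 2]\n",
    "print(chunks[2])  # [[10, 11], [110, 111]]\n",
    "```"
   ]
  },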
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WiTF1VMtruMU"
   },
   "outputs": [],
   "source": [
    "# truncate backprop every 5 time steps of each sequence\n",
    "trainer = pl.Trainer(truncated_bptt_steps=5)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8XI_kEWkS-nT"
   },
   "source": [
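    "To modify how the batch is split, override `LightningModule.tbptt_split_batch()`:\n",
    "\n",
    "```\n",
    "class LitMNIST(pl.LightningModule):\n",
    "    def tbptt_split_batch(self, batch, split_size):\n",
    "        # example: split every tensor along the time dimension (dim 1)\n",
    "        x, y = batch\n",
    "        splits = []\n",
    "        for t in range(0, x.size(1), split_size):\n",
    "            splits.append([x[:, t:t + split_size], y[:, t:t + split_size]])\n",
    "        return splits\n",
    "```\n"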
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oLbEmbmupwQ8"
   },
   "source": [
    "# reload_dataloaders_every_epoch\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CLdNGVv9xD_L"
   },
   "source": [
    "Set to True to reload the dataloaders every epoch (instead of loading them just once at the beginning of training).\n",
    "\n",
    "```\n",
    "# if False (default)\n",
    "train_loader = model.train_dataloader()\n",
    "for epoch in epochs:\n",
    "    for batch in train_loader:\n",
    "        ...\n",
    "\n",
    "# if True\n",
    "for epoch in epochs:\n",
    "    train_loader = model.train_dataloader()\n",
    "    for batch in train_loader:\n",
    "        ...\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "10AXthXxp311"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(reload_dataloaders_every_epoch=True)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f513EYl0bmmL"
   },
   "source": [
    "# Callbacks\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2pt7iGh4xNs5"
   },
   "source": [
    "\n",
    "Lightning Callbacks are self-contained programs that can be reused across projects.\n",
    "Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run. Lightning includes a few built-in callbacks, such as early stopping and model checkpointing, that can be enabled with flags, but you can also create your own callbacks to add any functionality to your models.\n",
    "\n",
    "The callback API includes hooks that allow you to add logic at every point of your training:\n",
    "setup, teardown, on_epoch_start, on_epoch_end, on_batch_start, on_batch_end, on_init_start, on_keyboard_interrupt etc. \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1t84gvDNsUuh"
   },
   "source": [
    "## callbacks\n",
    "\n",
    "Use **callbacks=** to pass a list of user-defined callbacks. These callbacks DO NOT replace the built-in callbacks (loggers or EarlyStopping).\n",
    "\n",
    "In this example, we create a dummy callback that prints a message when training starts and ends, using on_train_start and on_train_end hooks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oIXZYabub3f0"
   },
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks import Callback\n",
    "\n",
    "class PrintCallback(Callback):\n",
    "    def on_train_start(self, trainer, pl_module):\n",
    "        print(\"Training is started!\")\n",
    "    def on_train_end(self, trainer, pl_module):\n",
    "        print(\"Training is done.\")\n",
    "\n",
    "# a list of callbacks\n",
    "callbacks = [PrintCallback()]\n",
    "trainer = pl.Trainer(callbacks=callbacks)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cNF74CLYfJJu"
   },
   "source": [
    "# Model checkpointing\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2blgquBrxLtS"
   },
   "source": [
    "Checkpoints capture the exact value of all parameters used by a model.\n",
    "\n",
    "Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model.\n",
    "\n",
    "Lightning automates saving and loading checkpoints so you can restore a training session. It saves all the required parameters, including:\n",
    "* 16-bit scaling factor (apex)\n",
    "* Current epoch\n",
    "* Global step\n",
    "* Model state_dict\n",
    "* State of all optimizers\n",
    "* State of all learning rate schedulers\n",
    "* State of all callbacks\n",
    "* The hyperparameters used for that model if passed in as hparams (argparse.Namespace)\n",
    "\n",
    "### Automatic saving\n",
    "By default, Lightning saves a checkpoint in the working directory at the end of the first epoch and updates it after every epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XGu0JULrg9l7"
   },
   "outputs": [],
   "source": [
    "# default used by the Trainer\n",
    "trainer = pl.Trainer(default_root_dir=os.getcwd())\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3s9OjkGuhq1W"
   },
   "source": [
    "To change the checkpoint path, pass in **default_root_dir=**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DgdxkrIQhvfw"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(default_root_dir='/your/path/to/save/checkpoints')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Qyvj_bkWrJiE"
   },
   "source": [
    "\n",
    "You can also have Lightning update your checkpoint based on a specific metric that you log with `self.log`, by passing its key to `monitor=`. For example, to save a checkpoint based on the validation loss, logged as `val_loss`, you can pass:\n",
    "\n",
    "\n",
    "```\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    dirpath=os.getcwd(),\n",
    "    save_top_k=1,\n",
    "    verbose=True,\n",
    "    monitor='val_loss',\n",
    "    mode='min',\n",
    ")\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YzYMivw1rO1O"
   },
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "\n",
    "trainer = pl.Trainer(callbacks=[ModelCheckpoint(monitor='val_loss')])\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5hYs_FV8iDMn"
   },
   "source": [
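    "You can modify checkpointing behavior by creating your own `ModelCheckpoint` callback and passing it to the Trainer. \n",
    "You can control:\n",
    "* dirpath - the directory where checkpoints are saved\n",
    "* save_top_k - how many of the best models (by the monitored metric) to keep\n",
    "* verbose - whether to print a message when a checkpoint is saved\n",
    "* monitor - the logged metric to monitor\n",
    "* mode - whether the monitored metric should be minimized ('min') or maximized ('max')\n",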
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Tb1K2VYDiNTu"
   },
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "\n",
    "# a customized checkpoint callback (these are NOT the Trainer's defaults)\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    dirpath=os.getcwd(),\n",
    "    save_top_k=3,\n",
    "    verbose=True,\n",
    "    monitor='val_loss',\n",
    "    mode='min',\n",
    ")\n",
    "\n",
    "trainer = pl.Trainer(callbacks=[checkpoint_callback])\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
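  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what `save_top_k` and `mode` mean, the sketch below keeps only the `k` best checkpoints according to a monitored metric. It is a plain-Python toy, not `ModelCheckpoint`'s actual code: `update_top_k` is a hypothetical helper, and it tracks epoch numbers instead of writing checkpoint files.\n",
    "\n",
    "```python\n",
    "# Toy sketch (NOT ModelCheckpoint's code): keep the k best checkpoints\n",
    "# according to a monitored metric, evicting the worst one when needed.\n",
    "import heapq\n",
    "\n",
    "\n",
    "def update_top_k(kept, epoch, score, k, mode='min'):\n",
    "    # larger key = better checkpoint, for either mode\n",
    "    key = -score if mode == 'min' else score\n",
    "    if len(kept) < k:\n",
    "        heapq.heappush(kept, (key, epoch, score))\n",
    "        return None  # nothing evicted yet\n",
    "    if key > kept[0][0]:\n",
    "        # kept[0] is the worst kept checkpoint (heapq is a min-heap)\n",
    "        evicted = heapq.heapreplace(kept, (key, epoch, score))\n",
    "        return evicted[1]  # epoch whose checkpoint would be deleted\n",
    "    return epoch  # the new checkpoint is not in the top k\n",
    "\n",
    "\n",
    "val_losses = [0.9, 0.7, 0.8, 0.5, 0.6]\n",
    "kept = []\n",
    "for epoch, loss in enumerate(val_losses):\n",
    "    update_top_k(kept, epoch, loss, k=3, mode='min')\n",
    "\n",
    "print(sorted((epoch, score) for _, epoch, score in kept))\n",
    "# [(1, 0.7), (3, 0.5), (4, 0.6)]\n",
    "```\n",
    "\n",
    "With `mode='min'` and `k=3`, the checkpoints from epochs 0 and 2 are evicted as better validation losses arrive; with `mode='max'`, the same helper keeps the largest values instead."
   ]
  },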
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YKhZ6xRojJcl"
   },
   "source": [
    "You can disable checkpointing by passing `checkpoint_callback=False`:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Yt8zd2ZFjOXX"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(checkpoint_callback=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HcLy8asCjrj9"
   },
   "source": [
    "### Manual saving\n",
    "\n",
    "You can manually save checkpoints and restore your model from the checkpointed state.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kZSkMJf0jR4x"
   },
   "outputs": [],
   "source": [
    "trainer.fit(model, train_loader, val_loader)\n",
    "trainer.save_checkpoint(\"example.ckpt\")\n",
    "new_model = LitAutoEncoder.load_from_checkpoint(checkpoint_path=\"example.ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "X2d9cjVPj7CP"
   },
   "source": [
    "### Checkpoint Loading\n",
    "To load a model along with its weights, biases, and hyperparameters, use the following method:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BpAFfg5zkFmH"
   },
   "outputs": [],
   "source": [
    "model = LitAutoEncoder.load_from_checkpoint(PATH)\n",
    "\n",
    "print(model.learning_rate)\n",
    "# prints the learning_rate you used in this checkpoint\n",
    "\n",
    "model.eval()\n",
    "y_hat = model(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jTQ3mxSJkhFN"
   },
   "source": [
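    "If you don't want to use the hyperparameter values saved in the checkpoint, you can pass in your own. For example, given a model that saves its init arguments with `self.save_hyperparameters()`:"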
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IoMcOh9-kfUP"
   },
   "outputs": [],
   "source": [
    "class LitAutoEncoder(pl.LightningModule):\n",
    "\n",
    "    def __init__(self, in_dim, out_dim):\n",
    "        super().__init__()\n",
    "        self.save_hyperparameters()\n",
    "        self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ITPVY8mNknut"
   },
   "source": [
    "You can restore the model like this:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "H7XeRJzVkuY8"
   },
   "outputs": [],
   "source": [
    "# if you train and save the model like this, these values are used when\n",
    "# loading the weights, but you can override them\n",
    "model = LitAutoEncoder(in_dim=32, out_dim=10)\n",
    "\n",
    "# uses in_dim=32, out_dim=10\n",
    "model = LitAutoEncoder.load_from_checkpoint(PATH)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "14WwGpnVk0a4"
   },
   "outputs": [],
   "source": [
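    "# overrides the saved values: builds the model with in_dim=128, out_dim=10\n",
    "# (loading only succeeds if the saved weights have matching shapes)\n",
    "model = LitAutoEncoder.load_from_checkpoint(PATH, in_dim=128, out_dim=10)"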
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bY5s6wP_k1CU"
   },
   "source": [
    "\n",
    "\n",
    "## Restoring Training State (resume_from_checkpoint)\n",
    "If your training was cut short for some reason, you can resume exactly where you left off using the `resume_from_checkpoint` flag, which automatically restores the model, epoch, step, LR schedulers, apex, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9zfhHtyrk3rO"
   },
   "outputs": [],
   "source": [
    "model = LitAutoEncoder()\n",
    "trainer = pl.Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')\n",
    "\n",
    "# automatically restores model, epoch, step, LR schedulers, apex, etc.\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
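  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below illustrates why a checkpoint that stores the full training state (not just the weights) lets training resume exactly where it stopped. It is a plain-Python toy under simplifying assumptions, not Lightning's checkpoint format: `train`, `stop_after_epoch`, and the JSON state file are hypothetical, and an integer counter stands in for the model weights.\n",
    "\n",
    "```python\n",
    "# Toy sketch (NOT Lightning's checkpoint format): saving the full training\n",
    "# state lets an interrupted run continue from the same epoch and step.\n",
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def train(num_epochs, steps_per_epoch, ckpt_path, stop_after_epoch=None):\n",
    "    if os.path.exists(ckpt_path):\n",
    "        with open(ckpt_path) as f:\n",
    "            state = json.load(f)  # resume from the saved state\n",
    "    else:\n",
    "        state = {'epoch': 0, 'global_step': 0, 'weight': 0}\n",
    "    while state['epoch'] < num_epochs:\n",
    "        for _ in range(steps_per_epoch):\n",
    "            state['weight'] += 1  # stand-in for an optimizer step\n",
    "            state['global_step'] += 1\n",
    "        state['epoch'] += 1\n",
    "        with open(ckpt_path, 'w') as f:\n",
    "            json.dump(state, f)  # checkpoint at the end of every epoch\n",
    "        if state['epoch'] == stop_after_epoch:\n",
    "            break  # simulate an interruption\n",
    "    return state\n",
    "\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    path = os.path.join(tmp, 'toy.ckpt')\n",
    "    interrupted = train(5, 10, path, stop_after_epoch=2)\n",
    "    resumed = train(5, 10, path)\n",
    "\n",
    "print(interrupted['epoch'], resumed['epoch'], resumed['global_step'])  # 2 5 50\n",
    "```\n",
    "\n",
    "The resumed run continues from epoch 2 and global step 20 instead of restarting from zero, which is the behavior `resume_from_checkpoint` provides for the real model, optimizer, and scheduler state."
   ]
  },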
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xkKdvALFsmT2"
   },
   "source": [
    "## weights_save_path\n",
    "You can specify a directory for saving weight files using `weights_save_path`.\n",
    "\n",
    "(If you are using a custom checkpoint callback, the checkpoint callback will override this flag)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9OwHHFcCsrgT"
   },
   "outputs": [],
   "source": [
    "# save to your custom path\n",
    "trainer = pl.Trainer(weights_save_path='my/path')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PbNtlJ9Wsscf"
   },
   "outputs": [],
   "source": [
    "# if a checkpoint callback with its own path is used, it overrides the weights path\n",
    "# **NOTE: this saves weights to some/path NOT my/path\n",
    "checkpoint = ModelCheckpoint(dirpath='some/path')\n",
    "trainer = pl.Trainer(\n",
    "    callbacks=[checkpoint],\n",
    "    weights_save_path='my/path'\n",
    ")\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uDdxCuyHdWQt"
   },
   "source": [
    "# Early stopping\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fqAy3ihRxTfR"
   },
   "source": [
    "The EarlyStopping callback can be used to monitor a validation metric and stop the training when no improvement is observed, to help you avoid overfitting.\n",
    "\n",
    "To enable early stopping, initialize the `EarlyStopping` callback and pass it to the `callbacks=` Trainer flag. The callback monitors the logged metric whose key you pass to it.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lFx976CheH93"
   },
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
    "\n",
    "trainer = pl.Trainer(callbacks=[EarlyStopping('val_loss')])\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MwpJfTvjeOwF"
   },
   "source": [
    "You can customize the callback using the following params:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "V6I9h6HteK2U"
   },
   "outputs": [],
   "source": [
    "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
    "\n",
    "early_stop_callback = EarlyStopping(\n",
    "   monitor='val_accuracy',\n",
    "   min_delta=0.00,\n",
    "   patience=3,\n",
    "   verbose=False,\n",
    "   mode='max'\n",
    ")\n",
    "trainer = pl.Trainer(callbacks=[early_stop_callback])\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7TAIerPYe_Q1"
   },
   "source": [
    "The `EarlyStopping` callback runs at the end of every validation check, which, under the default configuration, happens after every training epoch. However, the frequency of validation can be modified by setting various parameters on the Trainer, for example `check_val_every_n_epoch` and `val_check_interval`. Note that the `patience` parameter counts the number of validation checks with no improvement, not the number of training epochs. Therefore, with `check_val_every_n_epoch=10` and `patience=3`, the trainer will perform at least 40 training epochs before being stopped."
   ]
  },
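  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The patience arithmetic above can be checked with a small simulation. This is a plain-Python toy under stated assumptions, not Lightning's `EarlyStopping` code: `first_stop_epoch` is a hypothetical helper, the metric is assumed to be checked exactly every `check_every_n_epochs` epochs, and a check counts as an improvement only if it beats the best value so far by more than `min_delta`.\n",
    "\n",
    "```python\n",
    "# Toy sketch (NOT Lightning's EarlyStopping code): patience counts\n",
    "# validation checks without improvement, not training epochs.\n",
    "def first_stop_epoch(scores_per_check, check_every_n_epochs, patience,\n",
    "                     min_delta=0.0, mode='min'):\n",
    "    best = None\n",
    "    wait = 0  # validation checks since the last improvement\n",
    "    for i, score in enumerate(scores_per_check):\n",
    "        epochs_done = (i + 1) * check_every_n_epochs  # epochs trained at this check\n",
    "        if mode == 'min':\n",
    "            improved = best is None or score < best - min_delta\n",
    "        else:\n",
    "            improved = best is None or score > best + min_delta\n",
    "        if improved:\n",
    "            best, wait = score, 0\n",
    "        else:\n",
    "            wait += 1\n",
    "            if wait >= patience:\n",
    "                return epochs_done  # training stops after this many epochs\n",
    "    return None  # early stopping never triggered\n",
    "\n",
    "\n",
    "never_improves = [1.0] * 10\n",
    "print(first_stop_epoch(never_improves, check_every_n_epochs=10, patience=3))  # 40\n",
    "```\n",
    "\n",
    "With a metric that never improves after the first check, the checks at epochs 20, 30, and 40 each increase the counter, so training stops after epoch 40; with `check_every_n_epochs=1`, the same metric would stop training after epoch 4."
   ]
  },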
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VoKrX2ENh9Fg"
   },
   "source": [
    "# Logging"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-CQTPKd7iKLm"
   },
   "source": [
    "Lightning has built-in integrations with various loggers such as TensorBoard, Weights & Biases (wandb), Comet, etc.\n",
    "\n",
    "You can pass any metric you want to log during training, such as loss or accuracy, to `self.log`. Similarly, pass any metric you want to log during the validation step to `self.log`.\n",
    "\n",
    "These values are passed to the logger of your choice. Use the `logger=` Trainer flag to pass in a logger, or an iterable collection of loggers, for experiment tracking.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ty5VPS3AiS8L"
   },
   "outputs": [],
   "source": [
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "\n",
    "# similar to the default logger used by the Trainer (the default picks the next free version automatically)\n",
    "logger = TensorBoardLogger(\n",
    "    save_dir=os.getcwd(),\n",
    "    version=1,\n",
    "    name='lightning_logs'\n",
    ")\n",
    "trainer = pl.Trainer(logger=logger)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jc5oWNpoiuuc"
   },
   "source": [
    "Lightning supports the use of multiple loggers; just pass a list to the Trainer.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BlYwMRRyivp_"
   },
   "outputs": [],
   "source": [
    "from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger\n",
    "logger1 = TensorBoardLogger('tb_logs', name='my_model')\n",
    "logger2 = TestTubeLogger('tb_logs', name='my_model')\n",
    "trainer = pl.Trainer(logger=[logger1, logger2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a7EyspQPh7iQ"
   },
   "source": [
    "## flush_logs_every_n_steps\n",
    "\n",
    "Use this flag to control how often (in steps) logged values are written to disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Em_XvsmyiBbk"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(flush_logs_every_n_steps=100)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_vDeKE98qsl1"
   },
   "source": [
    "## log_every_n_steps\n",
    "How often (in steps) to add logging rows. This does not write to disk.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HkqD7D_0w1Tt"
   },
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(log_every_n_steps=1000)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
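  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference between `log_every_n_steps` and `flush_logs_every_n_steps` concrete, here is a toy simulation. It is a plain-Python sketch of the behavior described above, not Lightning's logger code; `simulate_logging` is a hypothetical helper, and steps are counted from 1 for readability.\n",
    "\n",
    "```python\n",
    "# Toy sketch (NOT Lightning's logger code): log_every_n_steps decides which\n",
    "# steps produce a logged row; flush_logs_every_n_steps decides when the\n",
    "# buffered rows are written out.\n",
    "def simulate_logging(num_steps, log_every_n_steps, flush_logs_every_n_steps):\n",
    "    buffer, disk = [], []\n",
    "    for step in range(1, num_steps + 1):\n",
    "        if step % log_every_n_steps == 0:\n",
    "            buffer.append(step)  # add a logging row (in memory only)\n",
    "        if step % flush_logs_every_n_steps == 0:\n",
    "            disk.extend(buffer)  # write the buffered rows to disk\n",
    "            buffer.clear()\n",
    "    return disk, buffer\n",
    "\n",
    "\n",
    "disk, pending = simulate_logging(num_steps=250, log_every_n_steps=50,\n",
    "                                 flush_logs_every_n_steps=100)\n",
    "print(disk)     # [50, 100, 150, 200]\n",
    "print(pending)  # [250]\n",
    "```\n",
    "\n",
    "Rows are produced every 50 steps but only reach the simulated disk at steps 100 and 200; the row from step 250 is still pending in this toy because no flush step follows it."
   ]
  },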
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9uw0gfe422CT"
   },
   "source": [
    "# info logging"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dQXpt0aatDGo"
   },
   "source": [
    "## default_root_dir\n",
    "\n",
    "Default path for logs and weights when no logger or `pytorch_lightning.callbacks.ModelCheckpoint` callback is passed. On certain clusters you might want to separate where logs and checkpoints are stored. If you don't, use this argument for convenience. Paths can be local paths or remote paths such as `s3://bucket/path` or `hdfs://path/`. Credentials will need to be set up to use remote file paths."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CMmID2Bts5W3"
   },
   "source": [
    "## weights_summary\n",
    "Prints a summary of the weights when training begins. The default is `'top'`, which prints a summary of the top-level modules.\n",
    "\n",
    "Options: `'full'`, `'top'`, `None`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KTl6EdwDs6j2"
   },
   "outputs": [],
   "source": [
    "\n",
    "# print full summary of all modules and submodules\n",
    "trainer = pl.Trainer(weights_summary='full')\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R57cSLl9w9ma"
   },
   "outputs": [],
   "source": [
    "# don't print a summary\n",
    "trainer = pl.Trainer(weights_summary=None)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bSc2hU5AotAP"
   },
   "source": [
    "# progress bar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GgvbyDsBxcH6"
   },
   "source": [
    "## process_position\n",
    "\n",
    "Orders the progress bar. Useful when running multiple trainers on the same node.\n",
    "\n",
    "(This argument is ignored if a custom progress bar callback is passed to `callbacks`.)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6ekz8Es8owDn"
   },
   "outputs": [],
   "source": [
    "# default used by the Trainer\n",
    "trainer = pl.Trainer(process_position=0)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "itivQFgEphBU"
   },
   "source": [
    "## progress_bar_refresh_rate\n",
    "\n",
    "How often to refresh the progress bar (in steps). In notebooks, faster refresh rates (lower numbers) are known to crash them because of their screen refresh rates, so raise it to 50 or more. Set it to 0 to disable the progress bar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GKe6eVxmplL5"
   },
   "outputs": [],
   "source": [
    "# default used by the Trainer\n",
    "trainer = pl.Trainer(progress_bar_refresh_rate=1)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8rDHJOJbxNtf"
   },
   "outputs": [],
   "source": [
    "# disable progress bar\n",
    "trainer = pl.Trainer(progress_bar_refresh_rate=0)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NCNvYLwjpWne"
   },
   "source": [
    "# profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pRknrG_zpY6M"
   },
   "outputs": [],
   "source": [
    "# to profile standard training events\n",
    "trainer = pl.Trainer(profiler=True)\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ji6aWpU73kMM"
   },
   "source": [
    "You can also use Lightning's `AdvancedProfiler` if you want more detailed information about the time spent in each function call during a given action. The output is quite verbose, so you should only use it if you want very detailed reports.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "layG55pt316C"
   },
   "outputs": [],
   "source": [
    "from pytorch_lightning.profiler import AdvancedProfiler\n",
    "\n",
    "trainer = pl.Trainer(profiler=AdvancedProfiler())\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<code style=\"color:#792ee5;\">\n",
    "    <h1> <strong> Congratulations - Time to Join the Community! </strong>  </h1>\n",
    "</code>\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n",
    "\n",
    "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
    "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
    "\n",
    "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n",
    "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the `#general` channel.\n",
    "\n",
    "### Interested in SOTA AI models? Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
    "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
    "\n",
    "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n",
    "\n",
    "### Contributions!\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to the [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
    "\n",
    "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "* You can also contribute your own notebooks with useful examples!\n",
    "\n",
    "### Great thanks from the entire PyTorch Lightning Team for your interest!\n",
    "\n",
    "<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_static/images/logo.png?raw=true\" width=\"800\" height=\"200\" />"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "05-trainer-flags-overview.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
